import itertools
import time
from collections import *

import numpy as np
import torch as th

import dgl

Graph = namedtuple(
    "Graph",
    [
        "g",
        "src",
        "tgt",
        "tgt_y",
        "nids",
        "eids",
        "nid_arr",
        "n_nodes",
        "n_edges",
        "n_tokens",
    ],
)


class GraphPool:
    "Create a graph pool in advance to accelerate graph building phase in Transformer."

    def __init__(self, n=50, m=50):
        """
        args:
            n: maximum length of input sequence.
            m: maximum length of output sequence.
        """
        print("start creating graph pool...")
        tic = time.time()
        self.n, self.m = n, m
        g_pool = [[dgl.graph(([], [])) for _ in range(m)] for _ in range(n)]
        num_edges = {
            "ee": np.zeros((n, n)).astype(int),
            "ed": np.zeros((n, m)).astype(int),
            "dd": np.zeros((m, m)).astype(int),
        }
        for i, j in itertools.product(range(n), range(m)):
            src_length = i + 1
            tgt_length = j + 1

            g_pool[i][j].add_nodes(src_length + tgt_length)
            enc_nodes = th.arange(src_length, dtype=th.long)
            dec_nodes = th.arange(tgt_length, dtype=th.long) + src_length

            # enc -> enc
            us = enc_nodes.unsqueeze(-1).repeat(1, src_length).view(-1)
            vs = enc_nodes.repeat(src_length)
            g_pool[i][j].add_edges(us, vs)
            num_edges["ee"][i][j] = len(us)
            # enc -> dec
            us = enc_nodes.unsqueeze(-1).repeat(1, tgt_length).view(-1)
            vs = dec_nodes.repeat(src_length)
            g_pool[i][j].add_edges(us, vs)
            num_edges["ed"][i][j] = len(us)
            # dec -> dec
            indices = th.triu(th.ones(tgt_length, tgt_length)) == 1
            us = dec_nodes.unsqueeze(-1).repeat(1, tgt_length)[indices]
            vs = dec_nodes.unsqueeze(0).repeat(tgt_length, 1)[indices]
            g_pool[i][j].add_edges(us, vs)
            num_edges["dd"][i][j] = len(us)

        print(
            "successfully created graph pool, time: {0:0.3f}s".format(
                time.time() - tic
            )
        )
        self.g_pool = g_pool
        self.num_edges = num_edges

    def beam(self, src_buf, start_sym, max_len, k, device="cpu"):
        """
        Return a batched graph for beam search during inference of Transformer.
        args:
            src_buf: a list of input sequence
            start_sym: the index of start-of-sequence symbol
            max_len: maximum length for decoding
            k: beam size
            device: 'cpu' or 'cuda:*'
        """
        g_list = []
        src_lens = [len(_) for _ in src_buf]
        tgt_lens = [max_len] * len(src_buf)
        num_edges = {"ee": [], "ed": [], "dd": []}
        for src_len, tgt_len in zip(src_lens, tgt_lens):
            i, j = src_len - 1, tgt_len - 1
            for _ in range(k):
                g_list.append(self.g_pool[i][j])
            for key in ["ee", "ed", "dd"]:
                num_edges[key].append(int(self.num_edges[key][i][j]))

        g = dgl.batch(g_list)
        src, tgt = [], []
        src_pos, tgt_pos = [], []
        enc_ids, dec_ids = [], []
        e2e_eids, e2d_eids, d2d_eids = [], [], []
        n_nodes, n_edges, n_tokens = 0, 0, 0
        for src_sample, n, n_ee, n_ed, n_dd in zip(
            src_buf, src_lens, num_edges["ee"], num_edges["ed"], num_edges["dd"]
        ):
            for _ in range(k):
                src.append(th.tensor(src_sample, dtype=th.long, device=device))
                src_pos.append(th.arange(n, dtype=th.long, device=device))
                enc_ids.append(
                    th.arange(
                        n_nodes, n_nodes + n, dtype=th.long, device=device
                    )
                )
                n_nodes += n
                e2e_eids.append(
                    th.arange(
                        n_edges, n_edges + n_ee, dtype=th.long, device=device
                    )
                )
                n_edges += n_ee
                tgt_seq = th.zeros(max_len, dtype=th.long, device=device)
                tgt_seq[0] = start_sym
                tgt.append(tgt_seq)
                tgt_pos.append(th.arange(max_len, dtype=th.long, device=device))

                dec_ids.append(
                    th.arange(
                        n_nodes, n_nodes + max_len, dtype=th.long, device=device
                    )
                )
                n_nodes += max_len
                e2d_eids.append(
                    th.arange(
                        n_edges, n_edges + n_ed, dtype=th.long, device=device
                    )
                )
                n_edges += n_ed
                d2d_eids.append(
                    th.arange(
                        n_edges, n_edges + n_dd, dtype=th.long, device=device
                    )
                )
                n_edges += n_dd

        g.set_n_initializer(dgl.init.zero_initializer)
        g.set_e_initializer(dgl.init.zero_initializer)
        g = g.to(device).long()

        return Graph(
            g=g,
            src=(th.cat(src), th.cat(src_pos)),
            tgt=(th.cat(tgt), th.cat(tgt_pos)),
            tgt_y=None,
            nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
            eids={
                "ee": th.cat(e2e_eids),
                "ed": th.cat(e2d_eids),
                "dd": th.cat(d2d_eids),
            },
            nid_arr={"enc": enc_ids, "dec": dec_ids},
            n_nodes=n_nodes,
            n_edges=n_edges,
            n_tokens=n_tokens,
        )

    def __call__(self, src_buf, tgt_buf, device="cpu"):
        """
        Return a batched graph for the training phase of Transformer.
        args:
            src_buf: a set of input sequence arrays.
            tgt_buf: a set of output sequence arrays.
            device: 'cpu' or 'cuda:*'
        """
        g_list = []
        src_lens = [len(_) for _ in src_buf]
        tgt_lens = [len(_) - 1 for _ in tgt_buf]
        num_edges = {"ee": [], "ed": [], "dd": []}
        for src_len, tgt_len in zip(src_lens, tgt_lens):
            i, j = src_len - 1, tgt_len - 1
            g_list.append(self.g_pool[i][j])
            for key in ["ee", "ed", "dd"]:
                num_edges[key].append(int(self.num_edges[key][i][j]))

        g = dgl.batch(g_list)
        src, tgt, tgt_y = [], [], []
        src_pos, tgt_pos = [], []
        enc_ids, dec_ids = [], []
        e2e_eids, d2d_eids, e2d_eids = [], [], []
        n_nodes, n_edges, n_tokens = 0, 0, 0
        for src_sample, tgt_sample, n, m, n_ee, n_ed, n_dd in zip(
            src_buf,
            tgt_buf,
            src_lens,
            tgt_lens,
            num_edges["ee"],
            num_edges["ed"],
            num_edges["dd"],
        ):
            src.append(th.tensor(src_sample, dtype=th.long, device=device))
            tgt.append(th.tensor(tgt_sample[:-1], dtype=th.long, device=device))
            tgt_y.append(
                th.tensor(tgt_sample[1:], dtype=th.long, device=device)
            )
            src_pos.append(th.arange(n, dtype=th.long, device=device))
            tgt_pos.append(th.arange(m, dtype=th.long, device=device))
            enc_ids.append(
                th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device)
            )
            n_nodes += n
            dec_ids.append(
                th.arange(n_nodes, n_nodes + m, dtype=th.long, device=device)
            )
            n_nodes += m
            e2e_eids.append(
                th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device)
            )
            n_edges += n_ee
            e2d_eids.append(
                th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device)
            )
            n_edges += n_ed
            d2d_eids.append(
                th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device)
            )
            n_edges += n_dd
            n_tokens += m

        g.set_n_initializer(dgl.init.zero_initializer)
        g.set_e_initializer(dgl.init.zero_initializer)
        g = g.to(device).long()

        return Graph(
            g=g,
            src=(th.cat(src), th.cat(src_pos)),
            tgt=(th.cat(tgt), th.cat(tgt_pos)),
            tgt_y=th.cat(tgt_y),
            nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
            eids={
                "ee": th.cat(e2e_eids),
                "ed": th.cat(e2d_eids),
                "dd": th.cat(d2d_eids),
            },
            nid_arr={"enc": enc_ids, "dec": dec_ids},
            n_nodes=n_nodes,
            n_edges=n_edges,
            n_tokens=n_tokens,
        )
